{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a ShoeBot Sales Agent using LangGraph and Graphiti\n",
    "\n",
    "The following example demonstrates building an agent using LangGraph. Graphiti is used to personalize agent responses based on information learned from prior conversations. Additionally, a database of products is loaded into the Graphiti graph, enabling the agent to speak to these products.\n",
    "\n",
    "The agent implements:\n",
    "- persistence of new chat turns to Graphiti and recall of relevant Facts using the most recent message.\n",
    "- a tool for querying Graphiti for shoe information\n",
    "- an in-memory MemorySaver to maintain agent state.\n",
    "\n",
    "## Install dependencies\n",
    "```shell\n",
    "pip install graphiti-core langchain-openai langgraph ipywidgets\n",
    "```\n",
    "\n",
    "Ensure that you've followed the Graphiti installation instructions. In particular, installation of `neo4j`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import json\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import uuid\n",
    "from contextlib import suppress\n",
    "from datetime import datetime, timezone\n",
    "from pathlib import Path\n",
    "from typing import Annotated\n",
    "\n",
    "import ipywidgets as widgets\n",
    "from dotenv import load_dotenv\n",
    "from IPython.display import Image, display\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_logging():\n",
    "    logger = logging.getLogger()\n",
    "    logger.setLevel(logging.ERROR)\n",
    "    console_handler = logging.StreamHandler(sys.stdout)\n",
    "    console_handler.setLevel(logging.INFO)\n",
    "    formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')\n",
    "    console_handler.setFormatter(formatter)\n",
    "    logger.addHandler(console_handler)\n",
    "    return logger\n",
    "\n",
    "\n",
    "logger = setup_logging()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LangSmith integration (Optional)\n",
    "\n",
    "If you'd like to trace your agent using LangSmith, ensure that you have a `LANGSMITH_API_KEY` set in your environment.\n",
    "\n",
    "Then set `os.environ['LANGCHAIN_TRACING_V2'] = 'false'` to `true`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['LANGCHAIN_TRACING_V2'] = 'false'\n",
    "os.environ['LANGCHAIN_PROJECT'] = 'Graphiti LangGraph Tutorial'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure Graphiti\n",
    "\n",
    "Ensure that you have `neo4j` running and a database created. Ensure that you've configured the following in your environment.\n",
    "\n",
    "```bash\n",
    "NEO4J_URI=\n",
    "NEO4J_USER=\n",
    "NEO4J_PASSWORD=\n",
    "```"
   ]
  },
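  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick pre-flight check can catch missing connection settings before the Graphiti client is created. The sketch below is illustrative only: the helper name `missing_env_vars` is hypothetical and not part of Graphiti, and this notebook falls back to local defaults when a variable is unset.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "REQUIRED_VARS = ('NEO4J_URI', 'NEO4J_USER', 'NEO4J_PASSWORD')\n",
    "\n",
    "\n",
    "def missing_env_vars(names=REQUIRED_VARS, env=None):\n",
    "    # Hypothetical helper: return the names that are unset or empty\n",
    "    env = os.environ if env is None else env\n",
    "    return [name for name in names if not env.get(name)]\n",
    "\n",
    "\n",
    "missing = missing_env_vars()\n",
    "if missing:\n",
    "    print('Not set, defaults will be used:', ', '.join(missing))\n",
    "```"
   ]
  },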
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure Graphiti\n",
    "\n",
    "from graphiti_core import Graphiti\n",
    "from graphiti_core.edges import EntityEdge\n",
    "from graphiti_core.nodes import EpisodeType\n",
    "from graphiti_core.utils.maintenance.graph_data_operations import clear_data\n",
    "\n",
    "neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')\n",
    "neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')\n",
    "neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')\n",
    "\n",
    "client = Graphiti(\n",
    "    neo4j_uri,\n",
    "    neo4j_user,\n",
    "    neo4j_password,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating a database schema \n",
    "\n",
    "The following is only required for the first run of this notebook or when you'd like to start your database over.\n",
    "\n",
    "**IMPORTANT**: `clear_data` is destructive and will wipe your entire database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: This will clear the database\n",
    "await clear_data(client.driver)\n",
    "await client.build_indices_and_constraints()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Shoe Data into the Graph\n",
    "\n",
    "Load several shoe and related products into the Graphiti. This may take a while.\n",
    "\n",
    "\n",
    "**IMPORTANT**: This only needs to be done once. If you run `clear_data` you'll need to rerun this step."
   ]
  },
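  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each product is ingested as one JSON episode with its `images` field dropped. As a minimal sketch of that preparation step, the hypothetical helper `product_to_episode_body` below serializes a single product dict. The sample record is invented for illustration and is not taken from `manybirds_products.json`.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def product_to_episode_body(product: dict) -> str:\n",
    "    # Drop the bulky 'images' field and emit a valid JSON string\n",
    "    trimmed = {k: v for k, v in product.items() if k != 'images'}\n",
    "    return json.dumps(trimmed)\n",
    "\n",
    "\n",
    "# Invented sample record, for illustration only\n",
    "sample = {'title': 'Example Runner', 'price': 100, 'images': ['a.png']}\n",
    "print(product_to_episode_body(sample))\n",
    "```\n",
    "\n",
    "Using `json.dumps` rather than `str` matters here: `str` on a dict yields a Python literal with single quotes, which is not valid JSON."
   ]
  },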
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def ingest_products_data(client: Graphiti):\n",
    "    script_dir = Path.cwd().parent\n",
    "    json_file_path = script_dir / 'data' / 'manybirds_products.json'\n",
    "\n",
    "    with open(json_file_path) as file:\n",
    "        products = json.load(file)['products']\n",
    "\n",
    "    for i, product in enumerate(products):\n",
    "        await client.add_episode(\n",
    "            name=product.get('title', f'Product {i}'),\n",
    "            episode_body=str({k: v for k, v in product.items() if k != 'images'}),\n",
    "            source_description='ManyBirds products',\n",
    "            source=EpisodeType.json,\n",
    "            reference_time=datetime.now(timezone.utc),\n",
    "        )\n",
    "\n",
    "\n",
    "await ingest_products_data(client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a user node in the Graphiti graph\n",
    "\n",
    "In your own app, this step could be done later once the user has identified themselves and made their sales intent known. We do this here so we can configure the agent with the user's `node_uuid`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_EPISODE_MENTIONS\n",
    "\n",
    "user_name = 'jess'\n",
    "\n",
    "await client.add_episode(\n",
    "    name='User Creation',\n",
    "    episode_body=(f'{user_name} is interested in buying a pair of shoes'),\n",
    "    source=EpisodeType.text,\n",
    "    reference_time=datetime.now(timezone.utc),\n",
    "    source_description='SalesBot',\n",
    ")\n",
    "\n",
    "# let's get Jess's node uuid\n",
    "nl = await client._search(user_name, NODE_HYBRID_SEARCH_EPISODE_MENTIONS)\n",
    "\n",
    "user_node_uuid = nl.nodes[0].uuid\n",
    "\n",
    "# and the ManyBirds node uuid\n",
    "nl = await client._search('ManyBirds', NODE_HYBRID_SEARCH_EPISODE_MENTIONS)\n",
    "manybirds_node_uuid = nl.nodes[0].uuid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edges_to_facts_string(entities: list[EntityEdge]):\n",
    "    return '-' + '\\n- '.join([edge.fact for edge in entities])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import AIMessage, SystemMessage\n",
    "from langchain_core.tools import tool\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langgraph.graph import END, START, StateGraph, add_messages\n",
    "from langgraph.prebuilt import ToolNode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `get_shoe_data` Tool\n",
    "\n",
    "The agent will use this to search the Graphiti graph for information about shoes. We center the search on the `manybirds_node_uuid` to ensure we rank shoe-related data over user data.\n"
   ]
  },
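  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for why centering the search matters, here is a toy sketch of distance-based reranking in plain Python. It is not Graphiti's implementation: the graph, the fact tuples, and the helper names are all hypothetical. It only illustrates the idea that facts attached to nodes closer to the center node sort first.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "\n",
    "def hop_distances(adjacency: dict, center: str) -> dict:\n",
    "    # Breadth-first search: number of hops from the center to each reachable node\n",
    "    distances = {center: 0}\n",
    "    queue = deque([center])\n",
    "    while queue:\n",
    "        node = queue.popleft()\n",
    "        for neighbor in adjacency.get(node, []):\n",
    "            if neighbor not in distances:\n",
    "                distances[neighbor] = distances[node] + 1\n",
    "                queue.append(neighbor)\n",
    "    return distances\n",
    "\n",
    "\n",
    "def rerank_by_distance(facts: list, adjacency: dict, center: str) -> list:\n",
    "    # Each fact is (source_node, text); facts on unreachable nodes sort last\n",
    "    distances = hop_distances(adjacency, center)\n",
    "    return sorted(facts, key=lambda fact: distances.get(fact[0], float('inf')))\n",
    "\n",
    "\n",
    "# Hypothetical toy graph\n",
    "adjacency = {\n",
    "    'ManyBirds': ['Wool Runner'],\n",
    "    'Wool Runner': ['ManyBirds', 'jess'],\n",
    "    'jess': ['Wool Runner'],\n",
    "}\n",
    "facts = [('jess', 'jess wears size 8'), ('Wool Runner', 'Wool Runner is made of wool')]\n",
    "print(rerank_by_distance(facts, adjacency, 'ManyBirds'))\n",
    "```\n",
    "\n",
    "Centering on `'ManyBirds'` puts the product fact first, while centering on `'jess'` puts the user fact first."
   ]
  },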
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool\n",
    "async def get_shoe_data(query: str) -> str:\n",
    "    \"\"\"Search the graphiti graph for information about shoes\"\"\"\n",
    "    edge_results = await client.search(\n",
    "        query,\n",
    "        center_node_uuid=manybirds_node_uuid,\n",
    "        num_results=10,\n",
    "    )\n",
    "    return edges_to_facts_string(edge_results)\n",
    "\n",
    "\n",
    "tools = [get_shoe_data]\n",
    "tool_node = ToolNode(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(model='gpt-4.1-mini', temperature=0).bind_tools(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the tool node\n",
    "await tool_node.ainvoke({'messages': [await llm.ainvoke('wool shoes')]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chatbot Function Explanation\n",
    "\n",
    "The chatbot uses Graphiti to provide context-aware responses in a shoe sales scenario. Here's how it works:\n",
    "\n",
    "1. **Context Retrieval**: It searches the Graphiti graph for relevant information based on the latest message, using the user's node as the center point. This ensures that user-related facts are ranked higher than other information in the graph.\n",
    "\n",
    "2. **System Message**: It constructs a system message incorporating facts from Graphiti, setting the context for the AI's response.\n",
    "\n",
    "3. **Knowledge Persistence**: After generating a response, it asynchronously adds the interaction to the Graphiti graph, allowing future queries to reference this conversation.\n",
    "\n",
    "This approach enables the chatbot to maintain context across interactions and provide personalized responses based on the user's history and preferences stored in the Graphiti graph."
   ]
  },
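  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The knowledge-persistence step relies on a fire-and-forget task. The event loop keeps only a weak reference to tasks, so a common pattern is to hold a strong reference until the task finishes. The sketch below shows that pattern in isolation: `persist_turn` is a hypothetical stand-in for `client.add_episode`, and `run_in_background` is an illustrative helper, not something Graphiti or LangGraph provides.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "\n",
    "background_tasks = set()\n",
    "saved_turns = []\n",
    "\n",
    "\n",
    "async def persist_turn(text: str) -> None:\n",
    "    # Hypothetical stand-in for a slow write such as client.add_episode\n",
    "    await asyncio.sleep(0)\n",
    "    saved_turns.append(text)\n",
    "\n",
    "\n",
    "def run_in_background(coro) -> asyncio.Task:\n",
    "    # Keep a strong reference so the task is not garbage collected mid-flight\n",
    "    task = asyncio.create_task(coro)\n",
    "    background_tasks.add(task)\n",
    "    task.add_done_callback(background_tasks.discard)\n",
    "    return task\n",
    "\n",
    "\n",
    "async def main() -> None:\n",
    "    run_in_background(persist_turn('jess: I wear size 8'))\n",
    "    # Wait for outstanding writes before shutting down\n",
    "    await asyncio.gather(*background_tasks)\n",
    "\n",
    "\n",
    "asyncio.run(main())\n",
    "```\n",
    "\n",
    "Inside a notebook, where an event loop is already running, replace `asyncio.run(main())` with `await main()`."
   ]
  },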
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class State(TypedDict):\n",
    "    messages: Annotated[list, add_messages]\n",
    "    user_name: str\n",
    "    user_node_uuid: str\n",
    "\n",
    "\n",
    "async def chatbot(state: State):\n",
    "    facts_string = None\n",
    "    if len(state['messages']) > 0:\n",
    "        last_message = state['messages'][-1]\n",
    "        graphiti_query = f'{\"SalesBot\" if isinstance(last_message, AIMessage) else state[\"user_name\"]}: {last_message.content}'\n",
    "        # search graphiti using Jess's node uuid as the center node\n",
    "        # graph edges (facts) further from the Jess node will be ranked lower\n",
    "        edge_results = await client.search(\n",
    "            graphiti_query, center_node_uuid=state['user_node_uuid'], num_results=5\n",
    "        )\n",
    "        facts_string = edges_to_facts_string(edge_results)\n",
    "\n",
    "    system_message = SystemMessage(\n",
    "        content=f\"\"\"You are a skillfull shoe salesperson working for ManyBirds. Review information about the user and their prior conversation below and respond accordingly.\n",
    "        Keep responses short and concise. And remember, always be selling (and helpful!)\n",
    "\n",
    "        Things you'll need to know about the user in order to close a sale:\n",
    "        - the user's shoe size\n",
    "        - any other shoe needs? maybe for wide feet?\n",
    "        - the user's preferred colors and styles\n",
    "        - their budget\n",
    "\n",
    "        Ensure that you ask the user for the above if you don't already know.\n",
    "\n",
    "        Facts about the user and their conversation:\n",
    "        {facts_string or 'No facts about the user and their conversation'}\"\"\"\n",
    "    )\n",
    "\n",
    "    messages = [system_message] + state['messages']\n",
    "\n",
    "    response = await llm.ainvoke(messages)\n",
    "\n",
    "    # add the response to the graphiti graph.\n",
    "    # this will allow us to use the graphiti search later in the conversation\n",
    "    # we're doing async here to avoid blocking the graph execution\n",
    "    asyncio.create_task(\n",
    "        client.add_episode(\n",
    "            name='Chatbot Response',\n",
    "            episode_body=f'{state[\"user_name\"]}: {state[\"messages\"][-1]}\\nSalesBot: {response.content}',\n",
    "            source=EpisodeType.message,\n",
    "            reference_time=datetime.now(timezone.utc),\n",
    "            source_description='Chatbot',\n",
    "        )\n",
    "    )\n",
    "\n",
    "    return {'messages': [response]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the Agent\n",
    "\n",
    "This section sets up the Agent's LangGraph graph:\n",
    "\n",
    "1. **Graph Structure**: It defines a graph with nodes for the agent (chatbot) and tools, connected in a loop.\n",
    "\n",
    "2. **Conditional Logic**: The `should_continue` function determines whether to end the graph execution or continue to the tools node based on the presence of tool calls.\n",
    "\n",
    "3. **Memory Management**: It uses a MemorySaver to maintain conversation state across turns. This is in addition to using Graphiti for facts."
   ]
  },
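  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routing decision can be exercised without an LLM. The sketch below restates the same logic as a synchronous function and feeds it stand-in messages. `FakeMessage` and `route` are hypothetical names used only for this illustration.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FakeMessage:\n",
    "    # Hypothetical stand-in for an AIMessage: only the fields the router reads\n",
    "    content: str = ''\n",
    "    tool_calls: list = field(default_factory=list)\n",
    "\n",
    "\n",
    "def route(state: dict) -> str:\n",
    "    # Go to the tools node when the last message requests a tool call\n",
    "    last_message = state['messages'][-1]\n",
    "    return 'continue' if last_message.tool_calls else 'end'\n",
    "\n",
    "\n",
    "plain = {'messages': [FakeMessage(content='We have sizes 8 to 13.')]}\n",
    "calling = {'messages': [FakeMessage(tool_calls=[{'name': 'get_shoe_data'}])]}\n",
    "print(route(plain), route(calling))\n",
    "```\n",
    "\n",
    "The returned strings are the keys of the mapping passed to `add_conditional_edges`, which translates them into the next node to run."
   ]
  },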
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_builder = StateGraph(State)\n",
    "\n",
    "memory = MemorySaver()\n",
    "\n",
    "\n",
    "# Define the function that determines whether to continue or not\n",
    "async def should_continue(state, config):\n",
    "    messages = state['messages']\n",
    "    last_message = messages[-1]\n",
    "    # If there is no function call, then we finish\n",
    "    if not last_message.tool_calls:\n",
    "        return 'end'\n",
    "    # Otherwise if there is, we continue\n",
    "    else:\n",
    "        return 'continue'\n",
    "\n",
    "\n",
    "graph_builder.add_node('agent', chatbot)\n",
    "graph_builder.add_node('tools', tool_node)\n",
    "\n",
    "graph_builder.add_edge(START, 'agent')\n",
    "graph_builder.add_conditional_edges('agent', should_continue, {'continue': 'tools', 'end': END})\n",
    "graph_builder.add_edge('tools', 'agent')\n",
    "\n",
    "graph = graph_builder.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Our LangGraph agent graph is illustrated below."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with suppress(Exception):\n",
    "    display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Agent\n",
    "\n",
    "Let's test the agent with a single call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await graph.ainvoke(\n",
    "    {\n",
    "        'messages': [\n",
    "            {\n",
    "                'role': 'user',\n",
    "                'content': 'What sizes do the TinyBirds Wool Runners in Natural Black come in?',\n",
    "            }\n",
    "        ],\n",
    "        'user_name': user_name,\n",
    "        'user_node_uuid': user_node_uuid,\n",
    "    },\n",
    "    config={'configurable': {'thread_id': uuid.uuid4().hex}},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Viewing the Graph\n",
    "\n",
    "At this stage, the graph would look something like this. The `jess` node is `INTERESTED_IN` the `TinyBirds Wool Runner` node. The image below was generated using Neo4j Desktop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Image(filename='tinybirds-jess.png', width=850))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Agent interactively\n",
    "\n",
    "The following code will run the agent in an event loop. Just enter a message into the box and click submit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conversation_output = widgets.Output()\n",
    "config = {'configurable': {'thread_id': uuid.uuid4().hex}}\n",
    "user_state = {'user_name': user_name, 'user_node_uuid': user_node_uuid}\n",
    "\n",
    "\n",
    "async def process_input(user_state: State, user_input: str):\n",
    "    conversation_output.append_stdout(f'\\nUser: {user_input}\\n')\n",
    "    conversation_output.append_stdout('\\nAssistant: ')\n",
    "\n",
    "    graph_state = {\n",
    "        'messages': [{'role': 'user', 'content': user_input}],\n",
    "        'user_name': user_state['user_name'],\n",
    "        'user_node_uuid': user_state['user_node_uuid'],\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        async for event in graph.astream(\n",
    "            graph_state,\n",
    "            config=config,\n",
    "        ):\n",
    "            for value in event.values():\n",
    "                if 'messages' in value:\n",
    "                    last_message = value['messages'][-1]\n",
    "                    if isinstance(last_message, AIMessage) and isinstance(\n",
    "                        last_message.content, str\n",
    "                    ):\n",
    "                        conversation_output.append_stdout(last_message.content)\n",
    "    except Exception as e:\n",
    "        conversation_output.append_stdout(f'Error: {e}')\n",
    "\n",
    "\n",
    "def on_submit(b):\n",
    "    user_input = input_box.value\n",
    "    input_box.value = ''\n",
    "    asyncio.create_task(process_input(user_state, user_input))\n",
    "\n",
    "\n",
    "input_box = widgets.Text(placeholder='Type your message here...')\n",
    "submit_button = widgets.Button(description='Send')\n",
    "submit_button.on_click(on_submit)\n",
    "\n",
    "conversation_output.append_stdout('Asssistant: Hello, how can I help you find shoes today?')\n",
    "\n",
    "display(widgets.VBox([input_box, submit_button, conversation_output]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
